{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e52af24e",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fdeea0fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "from rdkit import Chem\n",
    "from rdkit import RDLogger\n",
    "RDLogger.DisableLog('rdApp.*')                                                                                                                                                       \n",
    "\n",
    "from rdkit.Chem import Draw\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from sklearn.metrics import roc_auc_score as ras\n",
    "from sklearn.metrics import mean_squared_error"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f064887f",
   "metadata": {},
   "source": [
    "### Auglichem imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a6bcf959",
   "metadata": {},
   "outputs": [],
   "source": [
    "from auglichem.molecule import Compose, RandomAtomMask, RandomBondDelete, MotifRemoval\n",
    "from auglichem.molecule.data import MoleculeDatasetWrapper\n",
    "from auglichem.molecule.models import GCN, AttentiveFP, GINE, DeepGCN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7125df85",
   "metadata": {},
   "source": [
    "### Set up dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2887799f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using: ./data_download/clintox.csv.gz\n",
      "DATASET: ClinTox\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1484it [00:00, 8222.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating scaffolds...\n",
      "Generating scaffold 0/1477\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating scaffold 1000/1477\n",
      "About to sort in scaffold sets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/clo/miniforge3/envs/auglichem/lib/python3.8/site-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
      "  warnings.warn(out)\n"
     ]
    }
   ],
   "source": [
    "# Create transformation\n",
    "transform = Compose([\n",
    "    RandomAtomMask([0.1, 0.3]),\n",
    "    RandomBondDelete([0.1, 0.3]),\n",
    "    MotifRemoval()\n",
    "])\n",
    "transform = RandomAtomMask(0.1)\n",
    "\n",
    "# Initialize dataset object\n",
    "dataset = MoleculeDatasetWrapper(\"ClinTox\", data_path=\"./data_download\", transform=transform, batch_size=128)\n",
    "\n",
    "# Get train/valid/test splits as loaders\n",
    "train_loader, valid_loader, test_loader = dataset.get_data_loaders(\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "072b1301",
   "metadata": {},
   "source": [
    "### Initialize model with task from data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1f47e690",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get model\n",
    "num_outputs = len(dataset.labels.keys())\n",
    "model = AttentiveFP(task=dataset.task, output_dim=num_outputs)\n",
    "\n",
    "# Uncomment the following line to use GPU\n",
    "#model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c48a46f",
   "metadata": {},
   "source": [
    "### Initialize traning loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4981f4ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "if(dataset.task == 'classification'):\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "elif(dataset.task == 'regression'):\n",
    "    criterion = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e818fdf",
   "metadata": {},
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2855ac6b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "10it [00:36,  3.65s/it]\n",
      "10it [00:30,  3.06s/it]\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(2):\n",
    "    for bn, data in tqdm(enumerate(train_loader)):\n",
    "        optimizer.zero_grad()\n",
    "        loss = 0.\n",
    "        \n",
    "        # Get prediction for all data\n",
    "        _, pred = model(data)\n",
    "        \n",
    "        # To use GPU, data must be cast to cuda\n",
    "        #_, pred = model(data.cuda())\n",
    "\n",
    "        for idx, t in enumerate(train_loader.dataset.target):\n",
    "            # Get indices where target has a value\n",
    "            good_idx = np.where(data.y[:,idx]!=-999999999)\n",
    "            \n",
    "            # When the data is placed on GPU, target must come back to CPU\n",
    "            #good_idx = np.where(data.y.cpu()[:,idx]!=-999999999)\n",
    "\n",
    "            # Prediction is handled differently for classification and regression\n",
    "            if(train_loader.dataset.task == 'classification'):\n",
    "                current_preds = pred[:,2*(idx):2*(idx+1)][good_idx]\n",
    "                current_labels = data.y[:,idx][good_idx]\n",
    "            elif(train_loader.dataset.task == 'regression'):\n",
    "                current_preds = pred[:,idx][good_idx]\n",
    "                current_labels = data.y[:,idx][good_idx]\n",
    "            \n",
    "            loss += criterion(current_preds, current_labels)\n",
    "        \n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "772cfea1",
   "metadata": {},
   "source": [
    "### Test the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c81f966f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, test_loader, validation=False):\n",
    "    set_str = \"VALIDATION\" if validation else \"TEST\"\n",
    "    with torch.no_grad():\n",
    "        \n",
    "        # All targets we're evaluating\n",
    "        target_list = test_loader.dataset.target\n",
    "        \n",
    "        # Dictionaries to keep track of predictions and labels for all targets\n",
    "        all_preds = {target: [] for target in target_list}\n",
    "        all_labels = {target: [] for target in target_list}\n",
    "        \n",
    "        model.eval()\n",
    "        for data in test_loader:\n",
    "            # Get prediction for all data\n",
    "            _, pred = model(data)\n",
    "\n",
    "            # To use GPU, data must be cast to cuda\n",
    "            #_, pred = model(data.cuda())\n",
    "            \n",
    "            for idx, target in enumerate(target_list):\n",
    "                # Get indices where target has a value\n",
    "                good_idx = np.where(data.y[:,idx]!=-999999999)\n",
    "                \n",
    "                # When the data is placed on GPU, target must come back to CPU\n",
    "                #good_idx = np.where(data.y.cpu()[:,idx]!=-999999999)\n",
    "                \n",
    "                # Prediction is handled differently for classification and regression\n",
    "                if(train_loader.dataset.task == 'classification'):\n",
    "                    current_preds = pred[:,2*(idx):2*(idx+1)][good_idx][:,1]\n",
    "                    current_labels = data.y[:,idx][good_idx]\n",
    "                elif(train_loader.dataset.task == 'regression'):\n",
    "                    current_preds = pred[:,idx][good_idx]\n",
    "                    current_labels = data.y[:,idx][good_idx]\n",
    "                \n",
    "                # Save predictions and targets\n",
    "                all_preds[target].extend(list(current_preds.detach().cpu().numpy()))\n",
    "                all_labels[target].extend(list(current_labels.detach().cpu().numpy()))\n",
    "            \n",
    "        scores = {target: None for target in target_list}\n",
    "        for target in target_list:\n",
    "            if(test_loader.dataset.task == 'classification'):\n",
    "                scores[target] = ras(all_labels[target], all_preds[target])\n",
    "                print(\"{0} {1} ROC: {2:.5f}\".format(target, set_str, scores[target]))\n",
    "            elif(test_loader.dataset.task == 'regression'):\n",
    "                scores[target] = mean_squared_error(all_labels[target], all_preds[target],\n",
    "                                                    squared=False)\n",
    "                print(\"{0} {1} RMSE: {2:.5f}\".format(target, set_str, scores[target]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "092df20d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CT_TOX VALIDATION ROC: 0.47011\n",
      "FDA_APPROVED VALIDATION ROC: 0.30634\n",
      "CT_TOX TEST ROC: 0.50000\n",
      "FDA_APPROVED TEST ROC: 0.47002\n"
     ]
    }
   ],
   "source": [
    "evaluate(model, valid_loader, validation=True)\n",
    "evaluate(model, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09bc37df",
   "metadata": {},
   "source": [
    "### Model saving/loading example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "47a4c105",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model\n",
    "torch.save(model.state_dict(), \"./saved_models/example_gcn\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f6cafea4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CT_TOX TEST ROC: 0.51159\n",
      "FDA_APPROVED TEST ROC: 0.47002\n"
     ]
    }
   ],
   "source": [
    "# Instantiate new model and evaluate\n",
    "model = AttentiveFP(task=dataset.task, output_dim=num_outputs)\n",
    "evaluate(model, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1a559e08",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CT_TOX TEST ROC: 0.50000\n",
      "FDA_APPROVED TEST ROC: 0.47002\n"
     ]
    }
   ],
   "source": [
    "# Load saved model and evaluate\n",
    "model.load_state_dict(torch.load(\"./saved_models/example_gcn\"))\n",
    "evaluate(model, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d193cb74",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
